from torch import nn

class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    Three 5x5 convolutions (tanh activations) with 2x2 average pooling in
    between, followed by a two-layer fully connected classifier.

    The first ``nn.Linear`` expects exactly 120 features. That only holds
    when the final convolution produces a 1x1 map, which requires 32x32
    spatial inputs. Because pooling floors, sizes 32-35 also reduce to 1x1.
    Other sizes either fail inside a convolution or give a shape mismatch
    at the first ``nn.Linear``.

    Args:
        num_classes: Number of output scores per sample.
        in_channels: Channels of the input images (default 3).
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()

        # Spatial sizes in the comments assume a 32x32 input.
        # Layer order is kept fixed so that state_dict keys (e.g.
        # "feature_extractor.0.weight") and parameter init order stay stable.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1),  # 28x28
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),  # 14x14
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),  # 10x10
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),  # 5x5
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),  # 1x1
            nn.Tanh(),
        )

        self.classifier = nn.Sequential(
            # (N, 120, 1, 1) -> (N, 120)
            nn.Flatten(1),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=num_classes),
        )

    def forward(self, x):
        """Return raw class scores for a batch of images.

        Args:
            x: Tensor of shape (N, in_channels, H, W); see the class
                docstring for the supported H/W.

        Returns:
            Tensor of shape (N, num_classes). No softmax is applied.
        """
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x